import torch
from torch.utils.data import Dataset, DataLoader

# 路径前面添加r的目的是告诉他路径是字符串，只需要在windows中使用，linux中不存在该问题。
data_path = r"data\SMSSpamCollection"

## 完成数据集类
class MyDataset(Dataset):
    def __init__(self):
        self.lines = open(data_path,encoding='utf-8').readlines()

    def __len__(self):
        return len(self.lines)

    def __getitem__(self, index):
        ### 获取索引对应位置的一条数据
        cur_line = self.lines[index].strip()
        label = cur_line[:4].strip()
        content = cur_line[4:].strip()
        return label,content

my_dataset = MyDataset()

#### drop_last = True 意味舍弃最后一条不满足10条数据的数据
data_loader = DataLoader(dataset=my_dataset,batch_size=10,shuffle=True,drop_last=True)

if __name__ == '__main__':

    for i in data_loader:
        print(i)